{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Achari Berrada Youssef\n",
    "# This is an implementation of the Multinomial DBM Architecture.\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import config\n",
    "import os\n",
    "\n",
    "from yadlt.utils import utilities\n",
    "\n",
    "# Trained from the notebook run_flickr_image_dbm.ipynb\n",
    "image_layer1_W = np.load(\"image_dbm_layer_1_W.npy\")\n",
    "image_layer1_b = np.load(\"image_dbm_layer_1_b.npy\")\n",
    "image_layer2_W = np.load(\"image_dbm_layer_2_W.npy\")\n",
    "image_layer2_b = np.load(\"image_dbm_layer_2_b.npy\")\n",
    "\n",
    "# Trained from the notebook run_flickr_text_dbm.ipynb\n",
    "text_layer1_W = np.load(\"text_dbm_layer_1_W.npy\")\n",
    "text_layer1_b = np.load(\"text_dbm_layer_1_b.npy\")\n",
    "text_layer2_W = np.load(\"text_dbm_layer_2_W.npy\")\n",
    "text_layer2_b = np.load(\"text_dbm_layer_2_b.npy\")\n",
    "\n",
    "\n",
    "\n",
    "# Input placeholders\n",
    "img_input = tf.placeholder(tf.float32, [None, 3857], name=\"image-input\")\n",
    "txt_input = tf.placeholder(tf.float32, [None, 2000], name=\"text-input\")\n",
    "\n",
    "\n",
    "# ################# #\n",
    "# Image Forward DBM #\n",
    "# ################# #\n",
    "\n",
    "img_W1 = tf.Variable(image_layer1_W)\n",
    "img_b1 = tf.Variable(image_layer1_b)\n",
    "\n",
    "img_W2 = tf.Variable(image_layer2_W)\n",
    "img_b2 = tf.Variable(image_layer2_b)\n",
    "\n",
    "img_layer1 = tf.add(tf.matmul(img_input, img_W1), img_b1)\n",
    "img_layer1 = tf.nn.sigmoid(img_layer1)\n",
    "\n",
    "img_layer2 = tf.add(tf.matmul(img_layer1, img_W2), img_b2)\n",
    "img_layer2 = tf.nn.sigmoid(img_layer2)\n",
    "\n",
    "\n",
    "\n",
    "# ################ #\n",
    "# Text Forward DBM #\n",
    "# ################ #\n",
    "\n",
    "txt_W1 = tf.Variable(text_layer1_W)\n",
    "txt_b1 = tf.Variable(text_layer1_b)\n",
    "\n",
    "txt_W2 = tf.Variable(text_layer2_W)\n",
    "txt_b2 = tf.Variable(text_layer2_b)\n",
    "\n",
    "txt_layer1 = tf.add(tf.matmul(txt_input, txt_W1), txt_b1)\n",
    "txt_layer1 = tf.nn.sigmoid(txt_layer1)\n",
    "\n",
    "txt_layer2 = tf.add(tf.matmul(txt_layer1, txt_W2), txt_b2)\n",
    "txt_layer2 = tf.nn.sigmoid(txt_layer2)\n",
    "\n",
    "\n",
    "# ############## #\n",
    "# Multimodal DBM #\n",
    "# ############## #\n",
    "\n",
    "joint_representation_units = 512  # number of units in the joint representation layer\n",
    "\n",
    "multi_W = tf.Variable(tf.truncated_normal(shape=[2048, joint_representation_units], stddev=0.1), name='multimodal-weights')\n",
    "multi_b = tf.Variable(tf.constant(0.1, shape=[joint_representation_units]), name='multimodal-bias')\n",
    "\n",
    "multi_input = tf.concat([img_layer2, txt_layer2], 0)\n",
    "multi_output = tf.nn.sigmoid(tf.add(tf.matmul(multi_input, multi_W), multi_b))\n",
    "\n",
    "binary_output = utilities.sample_prob(hprobs, np.random.rand(data.shape[0], joint_representation_units))\n",
    "\n",
    "# ################## #\n",
    "# Image Backward DBM #\n",
    "# ################## #\n",
    "\n",
    "img_layer2b = tf.add(tf.matmul(binary_output, img_W2.T), img_b2)\n",
    "img_layer2b = tf.nn.sigmoid(img_layer2b)\n",
    "\n",
    "img_layer1b = tf.add(tf.matmul(img_layer2b, img_W1.T), img_b1)\n",
    "img_layer1b = tf.nn.sigmoid(img_layer1b)\n",
    "\n",
    "# ################# #\n",
    "# Text Backward DBM #\n",
    "# ################# #\n",
    "\n",
    "txt_layer2b = tf.add(tf.matmul(binary_output, txt_W2.T), txt_b2)\n",
    "txt_layer2b = tf.nn.sigmoid(txt_layer2b)\n",
    "\n",
    "txt_layer1b = tf.add(tf.matmul(txt_layer2b, txt_W1.T), txt_b1)\n",
    "txt_layer1b = tf.nn.sigmoid(txt_layer1b)\n",
    "\n",
    "# Now can use img_layer1b (which should be 3587) and txt_layer1b (which should be 2000) to compute the loss\n",
    "# function. For mean squared error cost:\n",
    "\n",
    "img_cost = tf.sqrt(tf.reduce_mean(tf.square(tf.sub(img_dataset, img_layer1b))))\n",
    "txt_cost = tf.sqrt(tf.reduce_mean(tf.square(tf.sub(txt_dataset, txt_layer1b))))\n",
    "\n",
    "# Optimizer\n",
    "\n",
    "img_train_step = tf.train.MomentumOptimizer(learning_rate, momentum).minimize(img_cost)\n",
    "txt_train_step = tf.train.MomentumOptimizer(learning_rate, momentum).minimize(txt_cost)\n",
    "\n",
    "# ####################### #\n",
    "# Code to train the model #\n",
    "# ####################### #\n",
    "\n",
    "# train_set and train_ref should be loaded, maybe from the labeled data directory\n",
    "labeled1 = np.load(os.path.join(config.flickr_labeled_path, \"combined-00001-of-00100.npy\"))\n",
    "labeled2 = np.load(os.path.join(config.flickr_labeled_path, \"combined-00002-of-00100.npy\"))\n",
    "labeled3 = np.load(os.path.join(config.flickr_labeled_path, \"combined-00003_0-of-00100.npy\"))\n",
    "labeled = np.concatenate((labeled1, labeled2, labeled3), axis=0)\n",
    "\n",
    "num_epochs = 5\n",
    "batch_size = 64\n",
    "\n",
    "shuff = zip(train_set, train_ref)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    for i in range(num_epochs):\n",
    "        np.random.shuffle(shuff)\n",
    "        batches = [_ for _ in utilities.gen_batches(shuff, batch_size)]\n",
    "\n",
    "        for batch in batches:\n",
    "            x_batch, y_batch = zip(*batch)\n",
    "            # run image train\n",
    "            sess.run(img_train_step, feed_dict={img_input: x_batch, image_ref: y_batch})\n",
    "            # run text train\n",
    "            sess.run(txt_train_step, feed_dict={txt_input: x_batch, text_ref: y_batch})\n",
    "\n",
    "        if validation_set is not None:\n",
    "            img_loss = sess.run(img_cost, feed_dict={img_input: img_validation_input, img_ref: img_validation_ref})\n",
    "            txt_loss = sess.run(txt_cost, feed_dict={txt_input: txt_validation_input, txt_ref: txt_validation_ref})\n",
    "            print(\"Image loss: %f\" % (img_loss))\n",
    "            print(\"Text loss: %f\" % (txt_loss))\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:py27]",
   "language": "python",
   "name": "conda-env-py27-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
